# -*- coding: utf-8 -*-
"""DCGAN.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1wAozkOLN_JTEqLQkenANxxJmO7PfAUa0
"""

import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
## define a function for the generator:
def make_generator_network(input_size=20,num_hidden_layers=1,num_hidden_units=100,num_output_units=784):
  """Build a fully connected generator.

  The network is ``num_hidden_layers`` Linear -> LeakyReLU pairs followed by
  a final Linear -> Tanh, so every output value lies in [-1, 1].

  Args:
    input_size: number of features of the latent input.
    num_hidden_layers: how many Linear/LeakyReLU pairs to stack.
    num_hidden_units: width of each hidden Linear layer.
    num_output_units: number of features produced by the last Linear layer.

  Returns:
    An ``nn.Sequential`` with children named ``fc_g{i}``, ``relu_g{i}``
    and ``tanh_g``.
  """
  # Feature widths seen along the hidden stack: input, then one entry per hidden layer.
  widths = [input_size] + [num_hidden_units] * num_hidden_layers
  named_layers = []
  for idx, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
    named_layers.append((f'fc_g{idx}', nn.Linear(n_in, n_out)))
    named_layers.append((f'relu_g{idx}', nn.LeakyReLU()))
  # Output projection, squashed by Tanh.
  named_layers.append((f'fc_g{num_hidden_layers}', nn.Linear(widths[-1], num_output_units)))
  named_layers.append(('tanh_g', nn.Tanh()))

  model = nn.Sequential()
  for layer_name, layer in named_layers:
    model.add_module(layer_name, layer)
  return model
## define a function for the discriminator:
def make_discriminator_network(input_size,num_hidden_layers=1,num_hidden_units=100,num_output_units=1):
  """Build a fully connected discriminator.

  Each hidden layer is Linear (without bias) -> LeakyReLU -> Dropout(0.5);
  the stack ends with Linear -> Sigmoid, so outputs lie in [0, 1].

  Args:
    input_size: number of input features.
    num_hidden_layers: how many Linear/LeakyReLU/Dropout triples to stack.
    num_hidden_units: width of each hidden Linear layer.
    num_output_units: number of outputs of the final Linear layer.

  Returns:
    An ``nn.Sequential`` with children named ``fc_d{i}``, ``relu_d{i}``,
    ``dropout_d{i}`` and ``sigmoid``.
  """
  model = nn.Sequential()
  for i in range(num_hidden_layers):
    model.add_module(f'fc_d{i}',nn.Linear(input_size, num_hidden_units, bias=False))
    model.add_module(f'relu_d{i}', nn.LeakyReLU())
    # The name must be unique per layer: registering a second module under
    # an existing name replaces the first one, which left only a single
    # Dropout in the network when several hidden layers were requested.
    model.add_module(f'dropout_d{i}', nn.Dropout(p=0.5))
    input_size = num_hidden_units
  model.add_module(f'fc_d{num_hidden_layers}',nn.Linear(input_size, num_output_units))
  model.add_module('sigmoid', nn.Sigmoid())
  return model

import torch
# Settings for the fully connected GAN: image shape, latent size and the
# depth/width of the generator and discriminator hidden stacks.
image_size = (28, 28)
z_size = 20
gen_hidden_layers, gen_hidden_size = 1, 100
disc_hidden_layers, disc_hidden_size = 1, 100
# Seed before constructing the models so their initial weights are repeatable.
torch.manual_seed(1)
# Generator: latent vector -> flattened image (one output per pixel).
gen_model = make_generator_network(
  input_size=z_size,
  num_hidden_layers=gen_hidden_layers,
  num_hidden_units=gen_hidden_size,
  num_output_units=np.prod(image_size),
)
print(gen_model)
# Discriminator: flattened image -> single probability.
disc_model = make_discriminator_network(
  input_size=np.prod(image_size),
  num_hidden_layers=disc_hidden_layers,
  num_hidden_units=disc_hidden_size,
)
print(disc_model)

import torchvision
from torchvision import transforms
image_path = './'
# mean/std are written as one-element tuples: Normalize documents them as
# per-channel sequences, and "(0.5)" is just the float 0.5 in parentheses.
# Subtracting 0.5 and dividing by 0.5 rescales ToTensor's output so it
# matches the Tanh range of the generator.
transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
# Downloads the MNIST training split into image_path if it is not there yet.
mnist_dataset = torchvision.datasets.MNIST(root=image_path, train=True,transform=transform, download=True)
# Inspect the first (image, label) pair to check value range and shape.
example, label = next(iter(mnist_dataset))
print(f'Min: {example.min()} Max: {example.max()}')
print(example.shape)

def create_noise(batch_size, z_size, mode_z):
  """Sample a batch of latent vectors.

  Args:
    batch_size: number of vectors to draw.
    z_size: length of each vector.
    mode_z: ``'uniform'`` draws values in [-1, 1); ``'normal'`` draws from
      a standard normal distribution.

  Returns:
    A tensor of shape ``(batch_size, z_size)``.

  Raises:
    ValueError: if ``mode_z`` is neither ``'uniform'`` nor ``'normal'``.
  """
  if mode_z == 'uniform':
    # torch.rand is in [0, 1); stretch and shift it to [-1, 1).
    return torch.rand(batch_size, z_size)*2 - 1
  if mode_z == 'normal':
    return torch.randn(batch_size, z_size)
  # Fail with a clear message instead of an unbound-variable error.
  raise ValueError(f"mode_z must be 'uniform' or 'normal', got {mode_z!r}")

# Sanity check: push one real batch and one noise batch through the
# untrained models and print the resulting tensor shapes.
from torch.utils.data import DataLoader
batch_size = 32
# shuffle=False, so the batch taken below is the first batch_size images.
dataloader = DataLoader(mnist_dataset, batch_size, shuffle=False)
input_real, label = next(iter(dataloader))
# Flatten each image to one row for the fully connected discriminator.
input_real = input_real.view(batch_size, -1)
# Re-seed so the noise drawn below is repeatable.
torch.manual_seed(1)
mode_z = 'uniform' # 'uniform' vs. 'normal'
input_z = create_noise(batch_size, z_size, mode_z)
print('input-z -- shape:', input_z.shape)
print('input-real -- shape:', input_real.shape)
# Generator output for the noise batch.
g_output = gen_model(input_z)
print('Output of G -- shape:', g_output.shape)
# Discriminator outputs for the real batch and for the generated batch.
d_proba_real = disc_model(input_real)
d_proba_fake = disc_model(g_output)
print('Disc. (real) -- shape:', d_proba_real.shape)
print('Disc. (fake) -- shape:', d_proba_fake.shape)

loss_fn = nn.BCELoss()
## Targets
# The generator is scored against "real" (ones) for its fakes; the
# discriminator is scored against ones for real inputs and zeros for fakes.
g_labels_real = torch.ones_like(d_proba_fake)
d_labels_real = torch.ones_like(d_proba_real)
d_labels_fake = torch.zeros_like(d_proba_fake)
## Binary cross-entropy of the discriminator outputs against those targets
g_loss = loss_fn(d_proba_fake, g_labels_real)
d_loss_real = loss_fn(d_proba_real, d_labels_real)
d_loss_fake = loss_fn(d_proba_fake, d_labels_fake)
## Report
print(f'Generator Loss: {g_loss:.4f}')
print(f'Discriminator Losses: Real {d_loss_real:.4f} Fake {d_loss_fake:.4f}')

import torch
batch_size = 64
# Seed both torch and numpy before building the loader and the models.
torch.manual_seed(1)
np.random.seed(1)
# drop_last=True keeps every batch at exactly batch_size samples.
mnist_dl = DataLoader(
  mnist_dataset,
  batch_size=batch_size,
  shuffle=True,
  drop_last=True,
)
# Prefer the GPU when one is visible to torch, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Fresh generator / discriminator, moved to the chosen device.
gen_model = make_generator_network(
  input_size=z_size,
  num_hidden_layers=gen_hidden_layers,
  num_hidden_units=gen_hidden_size,
  num_output_units=np.prod(image_size),
)
gen_model = gen_model.to(device)
disc_model = make_discriminator_network(
  input_size=np.prod(image_size),
  num_hidden_layers=disc_hidden_layers,
  num_hidden_units=disc_hidden_size,
)
disc_model = disc_model.to(device)
# One loss object and one Adam optimizer (default settings) per network.
loss_fn = nn.BCELoss()
g_optimizer = torch.optim.Adam(gen_model.parameters())
d_optimizer = torch.optim.Adam(disc_model.parameters())

## Train the discriminator
def d_train(x):
  """Run one optimisation step of the discriminator.

  Works on the module-level ``disc_model``, ``gen_model``, ``loss_fn``,
  ``d_optimizer`` and ``device``, and draws noise with ``create_noise``
  using the module-level ``z_size`` and ``mode_z``.

  Args:
    x: batch of real samples; the first dimension is the batch size. All
      remaining dimensions are flattened before the forward pass.

  Returns:
    A tuple ``(d_loss, d_proba_real, d_proba_fake)``: the summed real+fake
    loss as a Python float, and the detached discriminator outputs for the
    real and the generated batch.
  """
  disc_model.zero_grad()
  # Train discriminator with a real batch
  batch_size = x.size(0)
  # The discriminator created just before this function starts with a
  # Linear layer, so image batches must be flattened to (batch, features).
  # Input that is already flat is left unchanged by this view.
  x = x.view(batch_size, -1).to(device)
  d_labels_real = torch.ones(batch_size, 1, device=device)
  d_proba_real = disc_model(x)
  d_loss_real = loss_fn(d_proba_real, d_labels_real)
  # Train discriminator on a fake batch
  input_z = create_noise(batch_size, z_size, mode_z).to(device)
  g_output = gen_model(input_z)
  d_proba_fake = disc_model(g_output.detach()) # Detach to avoid backpropagating through G
  d_labels_fake = torch.zeros(batch_size, 1, device=device)
  d_loss_fake = loss_fn(d_proba_fake, d_labels_fake)
  # gradient backprop & optimize ONLY D's parameters
  d_loss = d_loss_real + d_loss_fake
  d_loss.backward()
  d_optimizer.step()
  # .item() already returns a plain float; the legacy .data hop is not needed.
  return d_loss.item(), d_proba_real.detach(), d_proba_fake.detach()

## Train the generator
def g_train():
  """Run one optimisation step of the generator.

  Works on the module-level ``gen_model``, ``disc_model``, ``loss_fn``,
  ``g_optimizer`` and ``device``. The number of noise vectors drawn equals
  the first dimension of the module-level ``fixed_z``.

  Returns:
    The generator loss of this step as a Python float.
  """
  gen_model.zero_grad()
  n_fake = fixed_z.size(0)
  noise = create_noise(n_fake, z_size, mode_z).to(device)
  # Score freshly generated samples with the discriminator.
  fake_scores = disc_model(gen_model(noise))
  # The generator is rewarded when its fakes are classified as real (ones).
  target_real = torch.ones(n_fake, 1, device=device)
  loss = loss_fn(fake_scores, target_real)
  # Only the generator's optimizer steps, so only G's parameters change.
  loss.backward()
  g_optimizer.step()
  return loss.item()

# One batch of latent vectors drawn once up front (not per training step).
# g_train reads its first dimension as the batch size.
# NOTE(review): the distribution used is whatever mode_z holds when this
# line runs; confirm it matches the mode_z used during training.
fixed_z = create_noise(batch_size, z_size, mode_z).to(device)
def create_samples(g_model, input_z):
  """Generate samples and rescale them with ``(x + 1) / 2``.

  Args:
    g_model: callable that maps ``input_z`` to generated samples.
    input_z: input passed to ``g_model`` unchanged.

  Returns:
    The rescaled output of ``g_model``, in whatever shape ``g_model``
    produced it (no reshape is applied). It is computed without gradient
    tracking.
  """
  with torch.no_grad():
    return (g_model(input_z) + 1) / 2.0
# Empty per-epoch histories: generated samples, mean D/G losses and mean
# discriminator outputs on real and generated batches.
epoch_samples, all_d_losses, all_g_losses = [], [], []
all_d_real, all_d_fake = [], []
# Training length and the noise distribution used from here on.
num_epochs = 100
mode_z = 'normal' # Define mode_z here or use the one defined earlier

import itertools
# NOTE(review): in this file nothing appends to the history lists between
# their initialisation and this point, so both panels are drawn empty here.
fig = plt.figure(figsize=(16, 6))
## Left panel: generator loss and halved discriminator loss
ax = fig.add_subplot(1, 2, 1)
# d_train returns the sum of the real and fake loss terms; halving it puts
# the curve on the scale of a single term.
half_d_losses = [d_total/2 for d_total in all_d_losses]
ax.plot(all_g_losses, label='Generator loss')
ax.plot(half_d_losses, label='Discriminator loss')
ax.legend(fontsize=20)
ax.set_xlabel('Iteration', size=15)
ax.set_ylabel('Loss', size=15)
## Right panel: mean discriminator output on real and generated batches
ax = fig.add_subplot(1, 2, 2)
ax.plot(all_d_real, label=r'Real: $D(\mathbf{x})$')
ax.plot(all_d_fake, label=r'Fake: $D(G(\mathbf{z}))$')
ax.legend(fontsize=20)
ax.set_xlabel('Iteration', size=15)
ax.set_ylabel('Discriminator output', size=15)
plt.show()

selected_epochs = [1, 2, 4, 10, 50, 100]
# Keep only the epochs for which a batch of samples was actually stored;
# indexing epoch_samples[e-1] for a missing epoch would raise IndexError.
plottable_epochs = [e for e in selected_epochs if 1 <= e <= len(epoch_samples)]
if plottable_epochs:
  fig = plt.figure(figsize=(10, 14))
  # One row per selected epoch, five samples per row.
  for i,e in enumerate(plottable_epochs):
    for j in range(5):
      ax = fig.add_subplot(len(selected_epochs), 5, i*5+j+1)
      ax.set_xticks([])
      ax.set_yticks([])
      if j == 0:
        # Label each row with its epoch, just left of the first image.
        ax.text(-0.06, 0.5, f'Epoch {e}',rotation=90, size=18, color='red',horizontalalignment='right',verticalalignment='center',transform=ax.transAxes)
      # A stored sample may be flat or carry a channel axis; imshow needs a
      # 2-D array, so bring it to image_size first.
      image = np.reshape(epoch_samples[e-1][j], image_size)
      ax.imshow(image, cmap='gray_r')
  plt.show()
else:
  print('No generated samples recorded for the selected epochs; skipping the sample grid.')

def make_generator_network(input_size, n_filters):
  """Build the convolutional (DCGAN-style) generator.

  Three ConvTranspose2d -> BatchNorm2d -> LeakyReLU(0.2) stages reduce the
  channel count from ``n_filters*4`` to ``n_filters``; a last
  ConvTranspose2d -> Tanh produces one output channel in [-1, 1]. For an
  input of shape ``(N, input_size, 1, 1)`` the spatial size grows
  1 -> 4 -> 7 -> 14 -> 28.

  Args:
    input_size: number of channels of the latent input.
    n_filters: base channel count of the transposed convolutions.

  Returns:
    An ``nn.Sequential`` of eleven layers.
  """
  def upsample_stage(c_in, c_out, kernel, stride, pad):
    # Bias is omitted on the convolution that feeds a batch-norm layer.
    return [
      nn.ConvTranspose2d(c_in, c_out, kernel, stride, pad, bias=False),
      nn.BatchNorm2d(c_out),
      nn.LeakyReLU(0.2),
    ]

  layers = []
  layers += upsample_stage(input_size, n_filters*4, 4, 1, 0)
  layers += upsample_stage(n_filters*4, n_filters*2, 3, 2, 1)
  layers += upsample_stage(n_filters*2, n_filters, 4, 2, 1)
  # Final stage: single output channel squashed by Tanh.
  layers += [nn.ConvTranspose2d(n_filters, 1, 4, 2, 1, bias=False), nn.Tanh()]
  return nn.Sequential(*layers)
class Discriminator(nn.Module):
  """Convolutional (DCGAN-style) discriminator for single-channel images.

  Four strided convolutions shrink a 28x28 input to 1x1
  (28 -> 14 -> 7 -> 4 -> 1) and a Sigmoid turns the result into a value in
  [0, 1]. Batch normalisation follows the second and third convolutions.

  Args:
    n_filters: base channel count of the convolutions.
  """
  def __init__(self, n_filters):
    super().__init__()
    self.network = nn.Sequential(
      nn.Conv2d(1, n_filters, 4, 2, 1, bias=False),
      nn.LeakyReLU(0.2),
      nn.Conv2d(n_filters, n_filters*2, 4, 2, 1, bias=False),
      nn.BatchNorm2d(n_filters * 2),
      nn.LeakyReLU(0.2),
      nn.Conv2d(n_filters*2, n_filters*4, 3, 2, 1, bias=False),
      nn.BatchNorm2d(n_filters*4),
      nn.LeakyReLU(0.2),
      nn.Conv2d(n_filters*4, 1, 4, 1, 0, bias=False),
      nn.Sigmoid()
    )
  def forward(self, input):
    """Return one probability per input image, shaped ``(batch, 1)``."""
    output = self.network(input)
    # Always keep the (batch, 1) layout. The former trailing squeeze(0)
    # turned a batch of one into shape (1,), which no longer matches the
    # (batch, 1) label tensors used with the loss.
    return output.view(-1, 1)
# Settings for the convolutional GAN.
z_size = 100
image_size = (28, 28)
n_filters = 32
# Build the generator first, then the discriminator, each on `device`.
gen_model = make_generator_network(input_size=z_size, n_filters=n_filters)
gen_model = gen_model.to(device)
print(gen_model)
disc_model = Discriminator(n_filters=n_filters)
disc_model = disc_model.to(device)
print(disc_model)

# Binary cross-entropy loss and one Adam optimizer per network; the
# generator uses a slightly larger learning rate than the discriminator.
loss_fn = nn.BCELoss()
g_optimizer = torch.optim.Adam(gen_model.parameters(), lr=0.0003)
d_optimizer = torch.optim.Adam(disc_model.parameters(), lr=0.0002)

def create_noise(batch_size, z_size, mode_z):
  """Sample a batch of latent inputs for the convolutional generator.

  Args:
    batch_size: number of latent inputs to draw.
    z_size: number of channels of each latent input.
    mode_z: ``'uniform'`` draws values in [-1, 1); ``'normal'`` draws from
      a standard normal distribution.

  Returns:
    A tensor of shape ``(batch_size, z_size, 1, 1)``.

  Raises:
    ValueError: if ``mode_z`` is neither ``'uniform'`` nor ``'normal'``.
  """
  if mode_z == 'uniform':
    # torch.rand is in [0, 1); stretch and shift it to [-1, 1).
    return torch.rand(batch_size, z_size, 1, 1)*2 - 1
  if mode_z == 'normal':
    return torch.randn(batch_size, z_size, 1, 1)
  # Fail with a clear message instead of an unbound-variable error.
  raise ValueError(f"mode_z must be 'uniform' or 'normal', got {mode_z!r}")

def d_train(x):
  """Run one optimisation step of the discriminator.

  Works on the module-level ``disc_model``, ``gen_model``, ``loss_fn``,
  ``d_optimizer`` and ``device``, and draws noise with ``create_noise``
  using the module-level ``z_size`` and ``mode_z``.

  Args:
    x: batch of real images; the first dimension is the batch size.

  Returns:
    A tuple ``(d_loss, d_proba_real, d_proba_fake)``: the summed real+fake
    loss as a Python float, and the detached discriminator outputs for the
    real and the generated batch.
  """
  disc_model.zero_grad()
  # Train discriminator with a real batch
  batch_size = x.size(0)
  x = x.to(device)
  d_labels_real = torch.ones(batch_size, 1, device=device)
  d_proba_real = disc_model(x)
  d_loss_real = loss_fn(d_proba_real, d_labels_real)
  # Train discriminator on a fake batch
  input_z = create_noise(batch_size, z_size, mode_z).to(device)
  g_output = gen_model(input_z)
  # Detach so this step does not backpropagate into the generator; only
  # the discriminator is being updated here.
  d_proba_fake = disc_model(g_output.detach())
  d_labels_fake = torch.zeros(batch_size, 1, device=device)
  d_loss_fake = loss_fn(d_proba_fake, d_labels_fake)
  # gradient backprop & optimize ONLY D's parameters
  d_loss = d_loss_real + d_loss_fake
  d_loss.backward()
  d_optimizer.step()
  # .item() already returns a plain float; the legacy .data hop is not needed.
  return d_loss.item(), d_proba_real.detach(), d_proba_fake.detach()

# Latent batch reused after every epoch so the saved samples are comparable.
fixed_z = create_noise(batch_size, z_size, mode_z).to(device)
# Per-epoch histories, reset before training starts.
epoch_samples, all_d_losses, all_g_losses, all_d_real, all_d_fake = [], [], [], [], []
torch.manual_seed(1)
for epoch in range(1, num_epochs+1):
  gen_model.train()
  # Per-step values collected during this epoch.
  d_losses, g_losses, d_vals_real, d_vals_fake = [], [], [], []
  for x, _ in mnist_dl:
    # One discriminator step followed by one generator step per batch.
    d_loss, d_proba_real, d_proba_fake = d_train(x)
    g_loss = g_train()
    d_losses.append(d_loss)
    g_losses.append(g_loss)
    d_vals_real.append(d_proba_real.mean().cpu())
    d_vals_fake.append(d_proba_fake.mean().cpu())

  # Reduce each per-step list to its epoch mean and store it.
  for history, step_values in (
    (all_d_losses, d_losses),
    (all_g_losses, g_losses),
    (all_d_real, d_vals_real),
    (all_d_fake, d_vals_fake),
  ):
    history.append(torch.tensor(step_values).mean())

  print(
    f'Epoch {epoch:03d} | Avg Losses >>'
    f' G/D {all_g_losses[-1]:.4f}'
    f'/{all_d_losses[-1]:.4f}'
    f' [D-Real: {all_d_real[-1]:.4f}'
    f' D-Fake: {all_d_fake[-1]:.4f}]'
  )

  # Switch the generator to eval mode while sampling from the fixed noise.
  gen_model.eval()
  samples = create_samples(gen_model, fixed_z)
  epoch_samples.append(samples.detach().cpu().numpy())

selected_epochs = [1, 2, 4, 10, 50, 100]
# Keep only the epochs for which a batch of samples was actually stored;
# indexing epoch_samples[e-1] for a missing epoch would raise IndexError
# (for example when fewer epochs were trained than the largest entry).
plottable_epochs = [e for e in selected_epochs if 1 <= e <= len(epoch_samples)]
if plottable_epochs:
  fig = plt.figure(figsize=(10, 14))
  # One row per selected epoch, five samples per row.
  for i,e in enumerate(plottable_epochs):
    for j in range(5):
      ax = fig.add_subplot(len(selected_epochs), 5, i*5+j+1)
      ax.set_xticks([])
      ax.set_yticks([])
      if j == 0:
        # Label each row with its epoch, just left of the first image.
        ax.text(-0.06, 0.5, f'Epoch {e}', rotation=90, size=18, color='red', horizontalalignment='right', verticalalignment='center', transform=ax.transAxes)
      image = epoch_samples[e-1][j].squeeze() # Remove the channel dimension
      ax.imshow(image, cmap='gray_r')
  plt.show()
else:
  print('No generated samples recorded for the selected epochs; skipping the sample grid.')